1use std::collections::{HashMap, HashSet, VecDeque};
10use std::sync::Arc;
11
12use serde_json::Value;
13use tokio::sync::mpsc;
14
15use super::blocks;
16use crate::drone::data_flow::ExecutionScope;
17use crate::drone::types::{BlockKind, BlockState, FlowEdge, FlowNode, DroneGraph};
18
19#[derive(Debug, Clone, serde::Serialize)]
21#[serde(tag = "kind", rename_all = "snake_case")]
22pub enum RunEvent {
23 RunStarted { run_id: String, drone_id: String },
24 BlockStarted { run_id: String, block_id: String },
25 BlockDone {
26 run_id: String,
27 block_id: String,
28 output: Value,
29 },
30 BlockError {
31 run_id: String,
32 block_id: String,
33 error: String,
34 },
35 RunDone {
36 run_id: String,
37 output: Value,
38 },
39 RunFailed { run_id: String, error: String },
40}
41
42pub struct RunHandle {
43 pub run_id: String,
44 pub events: mpsc::UnboundedReceiver<RunEvent>,
45 pub final_states: Arc<tokio::sync::Mutex<HashMap<String, BlockState>>>,
46}
47
48pub async fn run_drone(
52 drone_id: String,
53 graph: DroneGraph,
54) -> Result<RunHandle, String> {
55 let run_id = uuid::Uuid::new_v4().to_string();
56 let (tx, rx) = mpsc::unbounded_channel();
57 let final_states: Arc<tokio::sync::Mutex<HashMap<String, BlockState>>> =
58 Arc::new(tokio::sync::Mutex::new(HashMap::new()));
59
60 let states_for_task = final_states.clone();
61 let drone_id_for_evt = drone_id.clone();
62 let run_id_for_task = run_id.clone();
63 let tx_for_task = tx.clone();
64 tokio::spawn(async move {
65 let _ = tx_for_task.send(RunEvent::RunStarted {
66 run_id: run_id_for_task.clone(),
67 drone_id: drone_id_for_evt,
68 });
69 match execute(&run_id_for_task, &graph, &tx_for_task, &states_for_task).await {
70 Ok(output) => {
71 let _ = tx_for_task.send(RunEvent::RunDone {
72 run_id: run_id_for_task,
73 output,
74 });
75 }
76 Err(e) => {
77 let _ = tx_for_task.send(RunEvent::RunFailed {
78 run_id: run_id_for_task,
79 error: e,
80 });
81 }
82 }
83 });
84
85 Ok(RunHandle {
86 run_id,
87 events: rx,
88 final_states,
89 })
90}
91
92async fn execute(
93 run_id: &str,
94 graph: &DroneGraph,
95 tx: &mpsc::UnboundedSender<RunEvent>,
96 final_states: &Arc<tokio::sync::Mutex<HashMap<String, BlockState>>>,
97) -> Result<Value, String> {
98 if graph.nodes.is_empty() {
99 return Err("drone has no blocks".to_string());
100 }
101
102 let layers = topological_layers(graph)?;
103 let mut scope = ExecutionScope::new();
104 let mut response_output: Option<Value> = None;
105 let nodes_by_id: HashMap<String, &FlowNode> =
106 graph.nodes.iter().map(|n| (n.id.clone(), n)).collect();
107 let mut skipped: HashSet<String> = HashSet::new();
113
114 for layer in layers {
115 for block_id in layer {
116 let node = nodes_by_id
117 .get(&block_id)
118 .ok_or_else(|| format!("internal: missing node {block_id}"))?;
119 let kind = block_kind_of(node)?;
120
121 if should_skip(&block_id, graph, &nodes_by_id, &scope, &skipped) {
122 skipped.insert(block_id.clone());
123 mark_state(
124 final_states,
125 &block_id,
126 BlockState {
127 status: "skipped".to_string(),
128 output: None,
129 error: None,
130 started_at: None,
131 completed_at: Some(now_ms()),
132 },
133 )
134 .await;
135 continue;
136 }
137
138 let _ = tx.send(RunEvent::BlockStarted {
139 run_id: run_id.to_string(),
140 block_id: block_id.clone(),
141 });
142 let started_at = now_ms();
146 mark_state(
147 final_states,
148 &block_id,
149 BlockState {
150 status: "running".to_string(),
151 output: None,
152 error: None,
153 started_at: Some(started_at),
154 completed_at: None,
155 },
156 )
157 .await;
158
159 let result = match kind {
160 BlockKind::Variables => blocks::variables::run(node, &mut scope).await,
161 BlockKind::Api => blocks::api::run(node, &scope).await,
162 BlockKind::Condition => blocks::condition::run(node, &scope).await,
163 BlockKind::Response => blocks::response::run(node, &scope).await,
164 BlockKind::Agent => blocks::agent::run(node, &scope).await,
165 };
166
167 match result {
168 Ok(output) => {
169 scope.outputs.insert(block_id.clone(), output.clone());
170 let _ = tx.send(RunEvent::BlockDone {
171 run_id: run_id.to_string(),
172 block_id: block_id.clone(),
173 output: output.clone(),
174 });
175 mark_state(
176 final_states,
177 &block_id,
178 BlockState {
179 status: "done".to_string(),
180 output: Some(output.clone()),
181 error: None,
182 started_at: Some(started_at),
183 completed_at: Some(now_ms()),
184 },
185 )
186 .await;
187 if matches!(kind, BlockKind::Response) {
188 response_output = Some(
194 output
195 .get("value")
196 .cloned()
197 .unwrap_or(output),
198 );
199 }
200 }
201 Err(e) => {
202 let _ = tx.send(RunEvent::BlockError {
203 run_id: run_id.to_string(),
204 block_id: block_id.clone(),
205 error: e.clone(),
206 });
207 mark_state(
208 final_states,
209 &block_id,
210 BlockState {
211 status: "error".to_string(),
212 output: None,
213 error: Some(e.clone()),
214 started_at: Some(started_at),
215 completed_at: Some(now_ms()),
216 },
217 )
218 .await;
219 return Err(format!("block {block_id} failed: {e}"));
220 }
221 }
222 }
223 }
224
225 Ok(response_output.unwrap_or(Value::Null))
226}
227
228async fn mark_state(
229 states: &Arc<tokio::sync::Mutex<HashMap<String, BlockState>>>,
230 id: &str,
231 state: BlockState,
232) {
233 let mut g = states.lock().await;
234 g.insert(id.to_string(), state);
235}
236
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_ms() -> i64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
243
244pub fn topological_layers(graph: &DroneGraph) -> Result<Vec<Vec<String>>, String> {
248 let mut indegree: HashMap<String, usize> = HashMap::new();
249 let mut outedges: HashMap<String, Vec<String>> = HashMap::new();
250 let known: HashSet<String> = graph.nodes.iter().map(|n| n.id.clone()).collect();
251 for n in &graph.nodes {
252 indegree.entry(n.id.clone()).or_insert(0);
253 outedges.entry(n.id.clone()).or_insert_with(Vec::new);
254 }
255 for e in &graph.edges {
256 if !known.contains(&e.source) || !known.contains(&e.target) {
257 return Err(format!(
258 "edge references unknown node ({} → {})",
259 e.source, e.target
260 ));
261 }
262 *indegree.entry(e.target.clone()).or_insert(0) += 1;
263 outedges
264 .entry(e.source.clone())
265 .or_insert_with(Vec::new)
266 .push(e.target.clone());
267 }
268
269 let mut layers: Vec<Vec<String>> = Vec::new();
270 let mut frontier: VecDeque<String> = indegree
271 .iter()
272 .filter_map(|(id, deg)| if *deg == 0 { Some(id.clone()) } else { None })
273 .collect();
274
275 while !frontier.is_empty() {
276 let mut layer: Vec<String> = Vec::new();
277 let next_frontier: Vec<String> = frontier.drain(..).collect();
278 for id in &next_frontier {
279 layer.push(id.clone());
280 }
281 layer.sort();
283 for id in &layer {
284 if let Some(targets) = outedges.get(id).cloned() {
285 for t in targets {
286 if let Some(d) = indegree.get_mut(&t) {
287 if *d > 0 {
288 *d -= 1;
289 if *d == 0 {
290 frontier.push_back(t);
291 }
292 }
293 }
294 }
295 }
296 }
297 layers.push(layer);
298 }
299
300 let scheduled: usize = layers.iter().map(|l| l.len()).sum();
301 if scheduled != graph.nodes.len() {
302 return Err("drone contains a cycle".to_string());
303 }
304 Ok(layers)
305}
306
307fn block_kind_of(node: &FlowNode) -> Result<BlockKind, String> {
308 let kind = node
309 .data
310 .get("kind")
311 .and_then(|v| v.as_str())
312 .ok_or_else(|| format!("node {} has no `kind` field", node.id))?;
313 BlockKind::parse(kind).ok_or_else(|| format!("unknown block kind: {kind}"))
314}
315
316fn should_skip(
328 node_id: &str,
329 graph: &DroneGraph,
330 nodes_by_id: &HashMap<String, &FlowNode>,
331 scope: &ExecutionScope,
332 skipped: &HashSet<String>,
333) -> bool {
334 let mut had_incoming = false;
335 let mut had_active = false;
336 for edge in &graph.edges {
337 if edge.target != node_id {
338 continue;
339 }
340 had_incoming = true;
341 if edge_is_active(edge, nodes_by_id, scope, skipped) {
342 had_active = true;
343 break;
344 }
345 }
346 had_incoming && !had_active
347}
348
349fn edge_is_active(
352 edge: &FlowEdge,
353 nodes_by_id: &HashMap<String, &FlowNode>,
354 scope: &ExecutionScope,
355 skipped: &HashSet<String>,
356) -> bool {
357 if skipped.contains(&edge.source) {
358 return false;
359 }
360 let src_node = match nodes_by_id.get(&edge.source) {
361 Some(n) => n,
362 None => return true,
363 };
364 let is_condition = src_node
365 .data
366 .get("kind")
367 .and_then(|v| v.as_str())
368 == Some("condition");
369 if !is_condition {
370 return true;
371 }
372 let cond_result = scope
373 .outputs
374 .get(&edge.source)
375 .and_then(|v| v.get("result"))
376 .and_then(|v| v.as_bool());
377 match (edge.source_handle.as_deref(), cond_result) {
378 (Some("true"), Some(r)) => r,
379 (Some("false"), Some(r)) => !r,
380 _ => true,
384 }
385}
386
// Unit tests for layering (`topological_layers`), branch pruning
// (`should_skip` / `edge_is_active`) and end-to-end runs via `run_drone`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::drone::types::{FlowEdge, FlowNode, NodePosition};
    use serde_json::json;

    /// Minimal node fixture whose `data` is only `{ "kind": kind }`.
    fn n(id: &str, kind: &str) -> FlowNode {
        FlowNode {
            id: id.to_string(),
            position: NodePosition::default(),
            data: json!({ "kind": kind }),
            node_type: String::new(),
        }
    }

    /// Unconditional edge fixture `src -> dst` with no source/target handle.
    fn e(id: &str, src: &str, dst: &str) -> FlowEdge {
        FlowEdge {
            id: id.to_string(),
            source: src.to_string(),
            target: dst.to_string(),
            source_handle: None,
            target_handle: None,
        }
    }

    // Diamond a -> {b, c} -> d: three layers, the middle one sorted by id.
    #[test]
    fn topo_orders_diamond() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables"), n("b", "api"), n("c", "api"), n("d", "response")],
            edges: vec![
                e("e1", "a", "b"),
                e("e2", "a", "c"),
                e("e3", "b", "d"),
                e("e4", "c", "d"),
            ],
        };
        let layers = topological_layers(&g).unwrap();
        assert_eq!(layers.len(), 3);
        assert_eq!(layers[0], vec!["a"]);
        assert_eq!(layers[1], vec!["b", "c"]);
        assert_eq!(layers[2], vec!["d"]);
    }

    // a <-> b can never be scheduled, so layering must fail.
    #[test]
    fn topo_rejects_cycle() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables"), n("b", "api")],
            edges: vec![e("e1", "a", "b"), e("e2", "b", "a")],
        };
        assert!(topological_layers(&g).is_err());
    }

    // An edge pointing at a node id that does not exist is rejected.
    #[test]
    fn topo_rejects_unknown_node_in_edge() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables")],
            edges: vec![e("e1", "a", "ghost")],
        };
        assert!(topological_layers(&g).is_err());
    }

    // A lone node forms exactly one single-element layer.
    #[test]
    fn topo_single_node_one_layer() {
        let g = DroneGraph {
            nodes: vec![n("a", "response")],
            edges: vec![],
        };
        let layers = topological_layers(&g).unwrap();
        assert_eq!(layers, vec![vec!["a".to_string()]]);
    }

    /// Edge fixture with an optional `source_handle` (used for a condition
    /// block's `"true"` / `"false"` outputs in these tests).
    fn eh(id: &str, src: &str, dst: &str, handle: Option<&str>) -> FlowEdge {
        FlowEdge {
            id: id.to_string(),
            source: src.to_string(),
            target: dst.to_string(),
            source_handle: handle.map(|s| s.to_string()),
            target_handle: None,
        }
    }

    /// Builds the id -> node lookup that `should_skip` expects.
    fn nodes_map<'a>(g: &'a DroneGraph) -> HashMap<String, &'a FlowNode> {
        g.nodes.iter().map(|n| (n.id.clone(), n)).collect()
    }

    /// Scope in which block `cond_id` has already produced
    /// `{ "result": result }`.
    fn scope_with_cond(cond_id: &str, result: bool) -> ExecutionScope {
        let mut scope = ExecutionScope::new();
        scope
            .outputs
            .insert(cond_id.to_string(), json!({ "result": result }));
        scope
    }

    // A node with no incoming edges is never skipped.
    #[test]
    fn skip_root_node_runs() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables")],
            edges: vec![],
        };
        let nodes = nodes_map(&g);
        assert!(!should_skip(
            "a",
            &g,
            &nodes,
            &ExecutionScope::new(),
            &HashSet::new()
        ));
    }

    // An edge from a non-condition, non-skipped source keeps the target active.
    #[test]
    fn skip_unconditional_chain_runs() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables"), n("b", "api")],
            edges: vec![e("e1", "a", "b")],
        };
        let nodes = nodes_map(&g);
        assert!(!should_skip(
            "b",
            &g,
            &nodes,
            &ExecutionScope::new(),
            &HashSet::new()
        ));
    }

    // Condition evaluated to true: the "true" branch runs, the "false" one is pruned.
    #[test]
    fn skip_condition_false_branch_when_result_true() {
        let g = DroneGraph {
            nodes: vec![n("c", "condition"), n("t", "api"), n("f", "api")],
            edges: vec![
                eh("e1", "c", "t", Some("true")),
                eh("e2", "c", "f", Some("false")),
            ],
        };
        let nodes = nodes_map(&g);
        let scope = scope_with_cond("c", true);
        let skipped = HashSet::new();
        assert!(!should_skip("t", &g, &nodes, &scope, &skipped));
        assert!(should_skip("f", &g, &nodes, &scope, &skipped));
    }

    // Mirror of the previous test with the condition evaluated to false.
    #[test]
    fn skip_condition_true_branch_when_result_false() {
        let g = DroneGraph {
            nodes: vec![n("c", "condition"), n("t", "api"), n("f", "api")],
            edges: vec![
                eh("e1", "c", "t", Some("true")),
                eh("e2", "c", "f", Some("false")),
            ],
        };
        let nodes = nodes_map(&g);
        let scope = scope_with_cond("c", false);
        let skipped = HashSet::new();
        assert!(should_skip("t", &g, &nodes, &scope, &skipped));
        assert!(!should_skip("f", &g, &nodes, &scope, &skipped));
    }

    // Pruning propagates: a node whose only upstream was skipped is skipped too.
    #[test]
    fn skip_transitive_through_skipped_source() {
        let g = DroneGraph {
            nodes: vec![n("c", "condition"), n("f", "api"), n("x", "agent")],
            edges: vec![
                eh("e1", "c", "f", Some("false")),
                e("e2", "f", "x"),
            ],
        };
        let nodes = nodes_map(&g);
        let scope = scope_with_cond("c", true);
        let mut skipped = HashSet::new();
        skipped.insert("f".to_string());
        assert!(should_skip("x", &g, &nodes, &scope, &skipped));
    }

    // A join runs when at least one incoming edge is still active.
    #[test]
    fn join_runs_if_any_incoming_active() {
        let g = DroneGraph {
            nodes: vec![n("a", "variables"), n("b", "api"), n("d", "response")],
            edges: vec![e("e1", "a", "d"), e("e2", "b", "d")],
        };
        let nodes = nodes_map(&g);
        let mut skipped = HashSet::new();
        skipped.insert("b".to_string());
        assert!(!should_skip(
            "d",
            &g,
            &nodes,
            &ExecutionScope::new(),
            &skipped
        ));
    }

    // A condition edge without a "true"/"false" handle is never pruned.
    #[test]
    fn condition_edge_without_handle_is_permissive() {
        let g = DroneGraph {
            nodes: vec![n("c", "condition"), n("x", "api")],
            edges: vec![eh("e1", "c", "x", None)],
        };
        let nodes = nodes_map(&g);
        let scope = scope_with_cond("c", false);
        assert!(!should_skip("x", &g, &nodes, &scope, &HashSet::new()));
    }

    // Regression guard: the "done" state must keep the `started_at` recorded
    // when the block entered "running".
    #[tokio::test]
    async fn execute_preserves_started_at_in_block_state() {
        let mut vars_node = n("v1", "variables");
        vars_node.data = json!({
            "kind": "variables",
            "entries": [{ "name": "v", "value": 1 }]
        });
        let mut resp_node = n("r1", "response");
        resp_node.data = json!({ "kind": "response", "template": "done" });
        let g = DroneGraph {
            nodes: vec![vars_node, resp_node],
            edges: vec![e("e1", "v1", "r1")],
        };

        let handle = run_drone("wf1".to_string(), g).await.unwrap();
        let mut rx = handle.events;
        // Drain events until a terminal one arrives (or the channel closes).
        while let Some(ev) = rx.recv().await {
            if matches!(ev, RunEvent::RunDone { .. } | RunEvent::RunFailed { .. }) {
                break;
            }
        }

        let states = handle.final_states.lock().await;
        let v1 = states.get("v1").expect("v1 state");
        let r1 = states.get("r1").expect("r1 state");
        assert_eq!(v1.status, "done");
        assert!(
            v1.started_at.is_some() && v1.started_at.unwrap() > 0,
            "v1 started_at must survive the done transition; got {:?}",
            v1.started_at
        );
        assert!(v1.completed_at.is_some());
        assert!(v1.completed_at.unwrap() >= v1.started_at.unwrap());
        assert!(r1.started_at.is_some());
        assert!(r1.completed_at.is_some());
    }

    // `source_handle` must round-trip through JSON as camelCase `sourceHandle`.
    #[test]
    fn flow_edge_serializes_with_camelcase_handle_fields() {
        let edge = FlowEdge {
            id: "e1".to_string(),
            source: "a".to_string(),
            target: "b".to_string(),
            source_handle: Some("true".to_string()),
            target_handle: None,
        };
        let json = serde_json::to_string(&edge).unwrap();
        assert!(
            json.contains("\"sourceHandle\":\"true\""),
            "expected sourceHandle in JSON; got {json}"
        );
        let parsed: FlowEdge = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.source_handle.as_deref(), Some("true"));
    }

    // Full run. v = 10, so the condition "{{var.v}} < 5" is expected to be
    // false. Only the "false" response runs, the "true" response is recorded
    // as skipped, and the run output is the false branch's template.
    #[tokio::test]
    async fn execute_skips_pruned_branch_end_to_end() {
        let mut vars_node = n("v1", "variables");
        vars_node.data = json!({
            "kind": "variables",
            "entries": [{ "name": "v", "value": 10 }]
        });
        let mut cond_node = n("c1", "condition");
        cond_node.data = json!({
            "kind": "condition",
            "expr": "{{var.v}} < 5"
        });
        let mut t_resp = n("rt", "response");
        t_resp.data = json!({
            "kind": "response",
            "template": "hit_true"
        });
        let mut f_resp = n("rf", "response");
        f_resp.data = json!({
            "kind": "response",
            "template": "hit_false"
        });
        let g = DroneGraph {
            nodes: vec![vars_node, cond_node, t_resp, f_resp],
            edges: vec![
                e("e1", "v1", "c1"),
                eh("e2", "c1", "rt", Some("true")),
                eh("e3", "c1", "rf", Some("false")),
            ],
        };

        let handle = run_drone("wf1".to_string(), g).await.unwrap();
        let mut rx = handle.events;
        let mut got_done_ids: Vec<String> = Vec::new();
        let mut final_output: Option<Value> = None;
        while let Some(ev) = rx.recv().await {
            match ev {
                RunEvent::BlockDone { block_id, .. } => got_done_ids.push(block_id),
                RunEvent::RunDone { output, .. } => {
                    final_output = Some(output);
                    break;
                }
                RunEvent::RunFailed { error, .. } => panic!("run failed: {error}"),
                _ => {}
            }
        }

        assert!(got_done_ids.contains(&"rf".to_string()));
        assert!(
            !got_done_ids.contains(&"rt".to_string()),
            "true branch must NOT run when condition is false; got: {got_done_ids:?}"
        );

        let states = handle.final_states.lock().await;
        let rt_state = states.get("rt").expect("rt state recorded");
        assert_eq!(rt_state.status, "skipped");

        assert_eq!(final_output, Some(json!("hit_false")));
    }
}